
Conversation

@yizhang2077 (Collaborator) commented Oct 4, 2025

Motivation

Ref #10438. Add a radix cache for Mamba; support for page_size > 1 and Marconi will be implemented soon.

Co-authored-by: hanming-lu [email protected]
Co-authored-by: hzh0425 [email protected]
Co-authored-by: thalahors [email protected]

Modifications

ref: doc

Accuracy Tests

python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 4  --chunked-prefill-size 64 --max-running-requests 16

python3 benchmark/gsm8k/bench_sglang.py --num-question 1000
Accuracy: 0.949
Invalid: 0.000
Latency: 346.597 s
Output throughput: 485.867 token/s

Benchmarking and Profiling

# multi-turn benchmark

python3 -m sglang.bench_serving --backend sglang --dataset-name generated-shared-prefix --gsp-num-groups 50 --gsp-prompts-per-group 10 --gsp-system-prompt-len 10240 --gsp-question-len 256 --gsp-output-len 128 --max-concurrency 5  --port 30000

# without radix cache
 python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 4 --disable-radix-cache

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 5         
Successful requests:                     500       
Benchmark duration (s):                  189.37    
Total input tokens:                      5521483   
Total generated tokens:                  64000     
Total generated tokens (retokenized):    63980     
Request throughput (req/s):              2.64      
Input token throughput (tok/s):          29157.63  
Output token throughput (tok/s):         337.97    
Total token throughput (tok/s):          29495.60  
Concurrency:                             5.00      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1891.89   
Median E2E Latency (ms):                 1884.63   
---------------Time to First Token----------------
Mean TTFT (ms):                          890.59    
Median TTFT (ms):                        1065.17   
P99 TTFT (ms):                           1180.80   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           7.90      
Median ITL (ms):                         6.31      
P95 ITL (ms):                            6.57      
P99 ITL (ms):                            8.14      
Max ITL (ms):                            751.16    
==================================================

# with radix cache
 python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 4

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 5         
Successful requests:                     500       
Benchmark duration (s):                  141.38    
Total input tokens:                      5521483   
Total generated tokens:                  64000     
Total generated tokens (retokenized):    63985     
Request throughput (req/s):              3.54      
Input token throughput (tok/s):          39055.43  
Output token throughput (tok/s):         452.69    
Total token throughput (tok/s):          39508.12  
Concurrency:                             5.00      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1412.49   
Median E2E Latency (ms):                 1337.07   
---------------Time to First Token----------------
Mean TTFT (ms):                          507.33    
Median TTFT (ms):                        489.45    
P99 TTFT (ms):                           1336.39   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           7.14      
Median ITL (ms):                         6.30      
P95 ITL (ms):                            6.85      
P99 ITL (ms):                            8.87      
Max ITL (ms):                            1484.16   
==================================================

Checklist

@gemini-code-assist (Contributor)
Summary of Changes

Hello @yizhang2077, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a foundational feature for supporting Mamba radix cache (v0) within the SGLang system. The core objective is to enhance the efficiency of KV cache management for models incorporating Mamba architectures. This is achieved by implementing a specialized radix tree that intelligently handles both standard and Mamba-specific KV states, allowing for better resource utilization and faster inference. The changes span across memory allocation, request scheduling, and cache eviction policies, culminating in significant performance gains as evidenced by the provided benchmarks.
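For orientation only, here is a simplified sketch of the kind of node such a hybrid radix tree might keep; the field names are illustrative (though value/mamba_value mirror names used in the snippets later in this thread) and are not taken from the PR's actual code:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch


@dataclass
class HybridTreeNode:
    """Illustrative node of a hybrid radix tree: full-attention KV indices are
    shared along the matched prefix, while the Mamba state is a single per-node
    slot that must be copied (not shared) when a new request reuses it."""

    key: List[int] = field(default_factory=list)        # token ids covered by this edge
    value: Optional[torch.Tensor] = None                 # indices into the full-attention KV pool
    mamba_value: Optional[torch.Tensor] = None            # index into the Mamba state pool (None = tombstone)
    children: Dict[int, "HybridTreeNode"] = field(default_factory=dict)
    lock_ref: int = 0                                      # > 0 while a running request pins this node
```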

Highlights

  • Mamba Radix Cache Implementation: Introduced a new MambaRadixCache class to efficiently manage hybrid (full and Mamba) KV cache states, leveraging a radix tree structure for optimized prefix sharing.
  • Memory Management Enhancements: Updated memory_pool.py to support Mamba-specific memory allocation, freeing, and state copying/forking using torch.Tensor for improved GPU compatibility and efficiency.
  • Scheduler and Policy Integration: Modified the scheduling and batch management logic across schedule_batch.py, schedule_policy.py, and scheduler.py to seamlessly integrate the new MambaRadixCache, including mechanisms for Mamba cache eviction and detailed memory usage tracking.
  • Performance Improvements: Benchmarking results demonstrate a notable increase in request throughput (from 2.64 req/s to 3.54 req/s) and input/output token throughput, alongside reduced end-to-end and time-to-first-token latencies when the Mamba radix cache is enabled.
  • Unit Testing: Added comprehensive unit tests in test_mamba_unittest.py to validate the functionality of HybridLinearKVPool, MambaPool, and MambaRadixCache, ensuring correctness of allocation, eviction, and prefix matching.

@gemini-code-assist (Contributor) left a comment

Code Review

This pull request introduces support for Mamba radix cache, which is a significant feature enhancement. The implementation is comprehensive, touching upon scheduling, memory management, and the model execution flow. The new MambaRadixCache is well-structured, and unit tests have been added. I've identified a few areas for improvement, including a potential bug in an assertion, a type hint mismatch, and the use of a magic number that should be a constant. Overall, this is a solid contribution.

if self.is_hybrid_gdn:
    max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
    # for mamba cache radix, it needs to be divided by 3 (magic number for now). (yizhang2077)
    max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 3)
@gemini-code-assist (Contributor) commented:

medium

The code uses a magic number 3 to divide max_mamba_cache_size. The comment acknowledges this. It's better to define this as a named constant with a clear explanation of why this division is necessary. This improves code readability and maintainability. For example: MAMBA_CACHE_REQS_RATIO = 3 could be defined at the top of the file or in a constants module.
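For illustration, the suggested refactor might look roughly like the standalone helper below; the constant name and function are hypothetical, and the meaning of the 3x factor is taken from the comment in the quoted code, not from the merged PR:

```python
MAMBA_CACHE_REQS_RATIO = 3  # mamba cache slots budgeted per running request (magic number from the PR)


def cap_max_running_requests(
    max_num_reqs: int, max_mamba_cache_size: int, is_hybrid_gdn: bool
) -> int:
    """Illustrative helper: bound request concurrency by the mamba cache budget."""
    if is_hybrid_gdn:
        max_num_reqs = min(max_num_reqs, max_mamba_cache_size // MAMBA_CACHE_REQS_RATIO)
    return max_num_reqs
```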

@Swipe4057 (Contributor)

You need to fix the typo and rename the token_msg variables to token_usage_msg:



class MambaRadixCache(BasePrefixCache):
    def __init__(
Collaborator:

Is it compatible with MTP? The EAGLE fix should also be applied to MambaRadixCache.

@yizhang2077 (Collaborator, Author):

Yes, I think we can do that in another PR.

@Swipe4057 (Contributor)

During testing, I discovered that the server crashed when token_usage reached more than 0.99.

@yizhang2077 (Collaborator, Author) commented Oct 5, 2025

During testing, I discovered that the server crashed when token_usage reached more than 0.99.

Do you have a reproduction command? I think token_usage > 0.99 is an abnormal state. (It is too high, and other models will crash in this state as well.)

@Swipe4057 (Contributor)

During testing, I discovered that the server crashed when token_usage reached more than 0.99.

Do you have a reproduction command? I think token_usage > 0.99 is an abnormal state. (It is too high, and other models will crash in this state as well.)

Reproduction command (server: H100, tp-size=2):

python bench_serving.py \
    --host 127.0.0.1 \
    --port 30000 \
    --dataset-name generated-shared-prefix \
    --gsp-system-prompt-len 1000 \
    --gsp-question-prompt-len 2744 \
    --gsp-output-prompt-len 820 \
    --gsp-num-groups 8 \
    --gsp-prompts-per-group 128

The root cause is an incorrect memory-availability check for the Mamba pool: only the MHA pool is checked, which leads to an attempted allocation in a full Mamba pool and a subsequent server crash because mamba_pool.alloc() returns None.

Claude Sonnet's recommendations (a sketch of the dual-pool check follows the list):

  1. Scheduler.check_memory():
    • Check availability of both pools (MHA and Mamba) separately.
  2. PrefillAdder.budget_state():
    • For hybrid models with Mamba, check availability of both pools separately.
  3. ScheduleBatch.alloc_token_slots():
    • For Mamba pool, use req_to_token_pool.mamba_pool.alloc() instead of token_to_kv_pool_allocator.alloc().
  4. MambaRadixCache.match_prefix():
    • Add Mamba pool availability check before allocation.
  5. MambaRadixCache.evict_mamba():
    • Add verification that eviction will free sufficient memory.
  6. Scheduler._add_request_to_queue():
    • Add Mamba pool availability check before _prefetch_kvcache().
  7. HybridReqToTokenPool.alloc():
    • Check Mamba pool availability before allocation.
  8. Add logging when Mamba pool memory is insufficient for diagnostics.
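As a rough illustration of recommendations 1, 2, and 8, a dual-pool admission check could look like the sketch below; the pool arguments and the available_size() calls are assumptions based on this discussion, not verified SGLang internals:

```python
import logging

logger = logging.getLogger(__name__)


def can_admit_hybrid_request(
    token_to_kv_pool_allocator, mamba_pool, num_full_tokens: int, num_mamba_slots: int = 1
) -> bool:
    """Illustrative check: admit a hybrid (attention + Mamba) request only when BOTH
    pools have room; otherwise mamba_pool.alloc() may return None later and crash
    the server, as described above."""
    full_ok = token_to_kv_pool_allocator.available_size() >= num_full_tokens
    mamba_ok = mamba_pool.available_size() >= num_mamba_slots
    if not mamba_ok:
        logger.warning(
            "Mamba pool exhausted: need %d slot(s), have %d",
            num_mamba_slots,
            mamba_pool.available_size(),
        )
    return full_ok and mamba_ok
```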

@yizhang2077 (Collaborator, Author) commented Oct 5, 2025

@Swipe4057 The Mamba pool controls memory capacity by setting its available size to 3x max_running_requests here. We can keep mamba_usage to around 0.66 at most during this benchmark. I have tried your benchmark with Qwen3-Next-80B-A3B-Instruct-FP8 on H100 and it did not crash (but I found another bug).

@Swipe4057 (Contributor) commented Oct 5, 2025

@Swipe4057 The Mamba pool controls memory capacity by setting its available size to 3x max_running_requests here. We can keep mamba_usage to around 0.66 at most during this benchmark. I have tried your benchmark with Qwen3-Next-80B-A3B-Instruct-FP8 on H100 and it did not crash (but I found another bug).

Run the server with this command and try testing again:
environment:
- SGLANG_ENABLE_JIT_DEEPGEMM=1
command:
--model-path /data/models/Qwen3-Next-80B-A3B-Instruct
--served-model-name Qwen3-Next-80B-A3B-Instruct
--cuda-graph-max-bs 512
--cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 128 136 144 152 160 168 176 184 192 200 208 216 224 232 240 248 256 264 272 280 288 296 304 312 320 328 336 344 352 360 368 376 384 392 400 408 416 424 432 440 448 456 464 472 480 488 496 504 512
--sleep-on-idle
--port 8027
--host 0.0.0.0
--schedule-policy lof
--random-seed 11111
--context-length 131072
--grammar-backend xgrammar
--tool-call-parser qwen25
--enable-metrics
--quantization w8a8_fp8
--allow-auto-truncate
--mamba-ssm-dtype bfloat16
--max-running-requests 1024
--tp-size 2
--ep-size 2
--chunked-prefill-size 16384
--prefill-attention-backend flashinfer
--decode-attention-backend flashinfer
--mem-fraction-static 0.86
--max-running-requests 1024
--api-key 123

Model: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct

@yizhang2077 force-pushed the support_mamba_radix_cache branch from 72812db to 67d4e34 on October 5, 2025 18:15
@yizhang2077 (Collaborator, Author) commented Oct 5, 2025

@Swipe4057 I have tried it and it works fine. Could you share your server log? (The error and the mamba_usage values in the log are the important items.)

@yizhang2077 (Collaborator, Author)

Seems like the mamba tree cache sanity check is not running; let's add it?

def check_tree_cache(self):
    if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
        self.tree_cache.sanity_check()

open sanity_check
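For illustration, the extended check might look like the sketch below; the MambaRadixCache branch and the is_hybrid_gdn guard are assumptions about how it could be wired up, not necessarily the merged code:

```python
def check_tree_cache(self):
    if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
        self.tree_cache.sanity_check()
    # hypothetical addition: run the same periodic sanity check for the mamba radix tree
    elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
        self.tree_cache.sanity_check()
```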

@yizhang2077 force-pushed the support_mamba_radix_cache branch from 8305248 to 8022cdf on October 10, 2025 06:45
@ispobock (Collaborator)

@yizhang2077 Could you resolve the conflicts? Then we can merge it.

@Swipe4057 (Contributor) commented Oct 11, 2025

@yizhang2077 @hanming-lu @ispobock I retested your new code.
Here's the command to start the server on 2xH100:

--model-path /data/models/Qwen3-Next-80B-A3B-Instruct --served-model-name Qwen3-Next-80B-A3B-Instruct --cuda-graph-max-bs 512 --cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 128 136 144 152 160 168 176 184 192 200 208 216 224 232 240 248 256 264 272 280 288 296 304 312 320 328 336 344 352 360 368 376 384 392 400 408 416 424 432 440 448 456 464 472 480 488 496 504 512 --sleep-on-idle --port 8027 --host 0.0.0.0 --api-key ${SGLANG_API_KEY} --schedule-policy lof --random-seed 11111 --context-length 131072 --grammar-backend xgrammar --prefill-attention-backend flashinfer --decode-attention-backend flashinfer --tool-call-parser qwen25 --enable-metrics --quantization w8a8_fp8 --allow-auto-truncate --tp-size 2 --ep-size 2 --chunked-prefill-size 16384 --max-running-requests 1024 --mem-fraction-static 0.87 --mamba-ssm-dtype bfloat16 --mamba-full-memory-ratio 4

TEST-1:
Reproduction command:

python bench_serving.py \
    --host 127.0.0.1 \
    --port 30000 \
    --dataset-name generated-shared-prefix \
    --gsp-system-prompt-len 1000 \
    --gsp-question-prompt-len 2744 \
    --gsp-output-prompt-len 820 \
    --gsp-num-groups 8 \
    --gsp-prompts-per-group 128

results:
main, disable radix:
input_throughput: 18518

mr, enable radix:
input_throughput: 16174 - the results only got worse!

logs:
Server startup is slow and 503 errors are occurring (this does not happen on main).
[server log screenshots attached]

TEST-2:
Reproduction command:

python bench_serving.py \
    --host 127.0.0.1 \
    --port 30000 \
    --dataset-name generated-shared-prefix \
    --gsp-system-prompt-len 1000 \
    --gsp-question-prompt-len 0 \
    --gsp-output-prompt-len 1 \
    --gsp-num-groups 8 \
    --gsp-prompts-per-group 128

(Pay attention to --gsp-question-prompt-len 0 and --gsp-output-prompt-len 1, and to the results below.)

results:
main, disable radix:
input_throughput: 45264

mr, enable radix:
input_throughput: 39668

As you can see, the last test sends 8 groups of identical prompts of about 1000 tokens each, with no variable part and only 1 output token. Performance with the radix cache enabled dropped dramatically, and there was simply no cache match!

@Swipe4057 (Contributor) commented Oct 11, 2025

@yizhang2077 Try this code to fix the zero cache-hit issue:

def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
    """Find the matching prefix from the radix tree."""
    cow_mamba: bool = kwargs.get("cow_mamba", False)
    req: Req = kwargs.get("req", None)

    if self.disable or len(key) == 0:
        return MatchResult(
            device_indices=torch.empty(
                (0,),
                dtype=torch.int64,
                device=self.device,
            ),
            last_device_node=self.root_node,
            last_host_node=self.root_node,
        )

    # Get the full matched prefix (including tombstone nodes)
    full_value, full_last_node, mamba_last_node = self._match_prefix_helper(key)

    # copy mamba state to req local space if cow is true
    if cow_mamba and mamba_last_node.mamba_value is not None:
        # for reqs without mamba cache
        if req.mamba_pool_idx is None:
            dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
            # try to alloc again, protect mamba_last_node from eviction
            if dst_index is None:
                self.inc_lock_ref(mamba_last_node)
                self.evict_mamba(1)
                dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
                self.dec_lock_ref(mamba_last_node)
                assert dst_index is not None, "Can not alloc mamba cache"
            src_index = mamba_last_node.mamba_value
            self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
            req.mamba_pool_idx = dst_index[0]
        else:
            src_index = mamba_last_node.mamba_value
            dst_index = req.mamba_pool_idx.unsqueeze(0)
            self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)

    # Optimize: use torch.cat only if needed
    if full_value:
        value = torch.cat(full_value) if len(full_value) > 1 else full_value[0]
    else:
        value = torch.empty((0,), dtype=torch.int64, device=self.device)

    return MatchResult(
        device_indices=value,
        last_device_node=full_last_node,
        last_host_node=full_last_node,
    )

def _match_prefix_helper(
    self, key: RadixKey
) -> Tuple[List[torch.Tensor], TreeNode, TreeNode]:
    """
    Mamba prefix matching helper. Returns:
    - full_value: all matched full cache tokens (including through tombstone nodes)
    - full_last_node: last node with full cache match
    - mamba_last_node: last node with mamba_value (for COW)
    """
    node = self.root_node
    child_key = self.get_child_key_fn(key)
    value = []
    mamba_last_node = node  # Track last node with mamba state

    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]

        prefix_len = self.key_match_fn(child.key, key)
        if prefix_len < len(child.key):
            new_node = self._split_node(child.key, child, prefix_len)
            value.append(new_node.value)
            node = new_node
            # Update mamba_last_node if new_node has mamba state
            if new_node.mamba_value is not None:
                mamba_last_node = new_node
            break
        else:
            value.append(child.value)
            # Update mamba_last_node before moving to child
            if child.mamba_value is not None:
                mamba_last_node = child
            node = child
            key = key[prefix_len:]
            if len(key):
                child_key = self.get_child_key_fn(key)

    # update time for matched nodes - single pass
    self.full_lru_list.reset_node_and_parents_mru(node, self.root_node)
    if mamba_last_node != self.root_node and mamba_last_node.mamba_value is not None:
        self.mamba_lru_list.reset_node_and_parents_mru(mamba_last_node, self.root_node)

    return value, node, mamba_last_node
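For context, a rough sketch of how a caller might exercise this path; the RadixKey construction is assumed for illustration, while the cow_mamba/req kwargs and the MatchResult fields come from the snippet above:

```python
# Hypothetical call site: look up the cached prefix for an incoming request and
# copy-on-write the nearest ancestor's Mamba state into the request's own slot.
match = tree_cache.match_prefix(
    RadixKey(req.origin_input_ids),  # key construction assumed for illustration
    cow_mamba=True,                  # request a private copy of the matched Mamba state
    req=req,                         # the request whose mamba_pool_idx gets populated
)
prefix_indices = match.device_indices  # KV-pool indices of the matched full-attention prefix
last_node = match.last_device_node     # node to lock (inc_lock_ref) while the request runs
```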

TEST-1:

results:
main, disable radix:
input_throughput: 18518

mr+THIS FIX, enable radix:
input_throughput: improved from 16174 to 17993

TEST-2:

results:
main, disable radix:
input_throughput: 45264

mr+THIS FIX, enable radix:
input_throughput: improved from 39668 to 298052 (!)

@hanming-lu (Collaborator)

@Swipe4057 Let's merge this PR to add the Mamba radix cache functionality. You're welcome to make follow-up changes to improve performance!

@yizhang2077 force-pushed the support_mamba_radix_cache branch from 70f8e56 to c4273af on October 12, 2025 14:10
@zhyncs merged commit a55cf53 into main on Oct 13, 2025
127 of 147 checks passed
@zhyncs deleted the support_mamba_radix_cache branch on October 13, 2025 03:57
@lisp2025 commented Oct 13, 2025

/data/install/backup/sglang_main/python/sglang/srt/mem_cache/memory_pool.py", line 306, in alloc
    mid = self.mamba_pool.alloc(1)[0]
TypeError: 'NoneType' object is not subscriptable

This occurs with 20 requests and --max-mamba-cache-size 100000.

ShangmingCai added a commit to kvcache-ai/sglang that referenced this pull request Oct 13, 2025
Signed-off-by: Shangming Cai <[email protected]>
cctry pushed a commit that referenced this pull request Oct 13, 2025
Co-authored-by: hanming-lu <[email protected]>
Co-authored-by: hzh0425 <[email protected]>
Co-authored-by: thalahors <[email protected]>
Comment on lines +385 to 388
if config := self.mamba2_config:
    class_name = config.__class__.__name__
    logger.warning(f"{class_name} model detected, disable radix cache")
    self.server_args.disable_radix_cache = True
Collaborator:

wait, we disabled radix tree?
